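""" Build the census-region student-share table (Zearn vs. MDR schools) for the predocs.

Usage: python <this script> <base directory containing data/ and results/>
"""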

import os
import sys
import datetime
import importlib.util
import subprocess
import psutil
from typing import Dict

# Import polars. If it is missing, pip-install pinned versions and load them
# directly from the user site-packages directory by file location (presumably
# because that directory is not on the running interpreter's sys.path).
# NOTE(review): assumes pip performs a user install here -- the original code
# hard-coded /home/runners/.local/lib/python3.8/site-packages; confirm.
try:
    import polars as pl
except ModuleNotFoundError as e:
    import site

    print(f"could not find module: {e}, installing")
    output = subprocess.check_output(
        [sys.executable, "-m", "pip", "install", "polars==0.18.1", "typing_extensions==4.6.3"],
        stderr=subprocess.STDOUT,
    )
    print(f"output of installation: {output}")

    # Ask the interpreter where the user site-packages live instead of
    # hard-coding a user- and Python-version-specific path.
    _user_site = site.getusersitepackages()
    # typing_extensions is registered in sys.modules before polars is executed.
    for _name, _relpath in (
        ("typing_extensions", "typing_extensions.py"),
        ("polars", os.path.join("polars", "__init__.py")),
    ):
        _spec = importlib.util.spec_from_file_location(_name, os.path.join(_user_site, _relpath))
        _module = importlib.util.module_from_spec(_spec)
        # Register before executing so the module's own imports can find it.
        sys.modules[_name] = _module
        _spec.loader.exec_module(_module)
    pl = sys.modules["polars"]
    print("installation finished")

# Import pandas. If it is missing, pip-install pinned versions of pandas and
# the packages it is loaded alongside (et_xmlfile/openpyxl for Excel output,
# numpy), then load them directly from the user site-packages directory.
# NOTE(review): assumes pip performs a user install here -- the original code
# hard-coded /home/runners/.local/lib/python3.8/site-packages; confirm.
try:
    import pandas as pd
except ModuleNotFoundError as e:
    import site

    print(f"could not find module: {e}, installing")
    output = subprocess.check_output(
        [sys.executable, "-m", "pip", "install", "et_xmlfile==1.1.0", "openpyxl==3.1.2", "numpy==1.24.3", "pandas==2.0.2"],
        stderr=subprocess.STDOUT,
    )
    print(f"output of installation: {output}")

    # Ask the interpreter where the user site-packages live instead of
    # hard-coding a user- and Python-version-specific path.
    _user_site = site.getusersitepackages()
    # Dependencies are loaded (and registered in sys.modules) before pandas.
    for _name in ("et_xmlfile", "openpyxl", "numpy", "pandas"):
        _spec = importlib.util.spec_from_file_location(_name, os.path.join(_user_site, _name, "__init__.py"))
        _module = importlib.util.module_from_spec(_spec)
        # Register before executing so the package's own imports can find it.
        sys.modules[_name] = _module
        _spec.loader.exec_module(_module)
    pd = sys.modules["pandas"]
    print("installation finished")


def get_zearn_counts(path: str, snapshot_date: datetime.date = datetime.date(2020, 1, 27)) -> pl.LazyFrame:
    """ import built raw zearn data and total its students per zcta5 on one date

    Args:
        path: CSV with (at least) "date", "coverage_zip", "zcta5" and
            "students" columns; "date" is parsed as a date.
        snapshot_date: only rows from this date are kept. Defaults to
            2020-01-27, the date this function always used.

    Returns:
        Lazy frame with one row per zcta5 value (column renamed to "zipcode")
        and the summed "students" of the rows with coverage_zip == 1.
    """

    ldf_zearn = pl.scan_csv(path, dtypes = {"date": pl.Date})
    # A row is kept only when both conditions are true (nulls are dropped).
    ldf_zearn = ldf_zearn.filter(
        (pl.col("coverage_zip") == 1) & (pl.col("date") == snapshot_date)
    )
    ldf_zearn = ldf_zearn.groupby(["zcta5"]).agg([pl.col("students").sum()])
    # NOTE(review): ZCTA codes are renamed to "zipcode" and later joined
    # against the crosswalk's ZIP_CODE column; confirm that is intended.
    ldf_zearn = ldf_zearn.rename({"zcta5": "zipcode"})
    return ldf_zearn


def get_mdr_counts(path: str) -> pl.LazyFrame:
    """Lazily load the MDR school file and total students by zipcode.

    Args:
        path: CSV containing "MDR Schools Zipcode" and
            "MDR Schools All Grades" columns.

    Returns:
        Lazy frame with one row per "zipcode" and the summed "students".
    """
    column_names = {
        "MDR Schools Zipcode": "zipcode",
        "MDR Schools All Grades": "students",
    }
    return (
        pl.scan_csv(path)
        .rename(column_names)
        .groupby(["zipcode"])
        .agg([pl.col("students").sum()])
    )


def crosswalk_states(ldf: pl.LazyFrame, dir: str) -> pl.LazyFrame:
    """Attach a "state" column to ``ldf`` by zipcode via the zip-to-ZCTA crosswalk.

    Args:
        ldf: lazy frame with a "zipcode" column.
        dir: base directory; the crosswalk CSV is read from
            data/dvc/Crosswalks/Zip_to_zcta_crosswalk_2020.csv under it.

    Returns:
        ``ldf`` left-joined with the crosswalk's (zipcode, state) pairs. Every
        row of ``ldf`` is kept; rows with no matching zipcode get a null state.
    """
    crosswalk_path = os.path.join(dir, r"data/dvc/Crosswalks/Zip_to_zcta_crosswalk_2020.csv")
    # Crosswalk cells equal to "No ZCTA" are read as null.
    zip_to_state = (
        pl.scan_csv(crosswalk_path, null_values = ["No ZCTA"])
        .rename({"ZIP_CODE": "zipcode", "STATE": "state"})
        .select([pl.col("zipcode"), pl.col("state")])
    )
    return ldf.join(zip_to_state, on = "zipcode", how = "left")


def generate_census_regions_dict() -> Dict:
    """ create census region dict

    Returns:
        Dict mapping the four U.S. Census Bureau regions ("west", "midwest",
        "south", "northeast") to the two-letter postal codes of their 50 states
        plus DC. A new dict with new lists is built on every call.
    """

    region_dict = {}
    # Census West = Pacific (which includes Alaska and Hawaii) + Mountain
    # divisions. Leaving AK/HI out would count their students in share
    # denominators but in no region.
    region_dict["west"] = ["WA", "OR", "CA", "AK", "HI", "NV", "ID", "MT", "WY", "UT", "CO", "AZ", "NM"]
    region_dict["midwest"] = ["ND", "SD", "NE", "KS", "MN", "IA", "MO", "WI", "IL", "IN", "MI", "OH"]
    region_dict["south"] = ["TX", "OK", "AR", "LA", "MS", "AL", "TN", "KY", "FL", "GA", "SC", "NC", "VA", "WV", "DC", "MD", "DE"]
    region_dict["northeast"] = ["PA", "NJ", "NY", "CT", "RI", "MA", "VT", "NH", "ME"]
    return region_dict


def get_census_region_shares(ldf: pl.LazyFrame) -> Dict:
    """ get the shares of students in each census region

    Args:
        ldf: lazy frame with a "state" column (two-letter codes, possibly
            null) and a numeric "students" column.

    Returns:
        Dict mapping each region from generate_census_regions_dict() to the
        region's summed students divided by the summed students of ALL rows.
        Rows whose state is null or in no region count toward the total only,
        so the shares need not sum to 1.

    Raises:
        ValueError: if the total student count is zero.
    """

    region_dict = generate_census_regions_dict()
    # Compute every region total and the overall total in ONE query, so the
    # lazy pipeline behind ldf (CSV scan, groupby, join) is executed once
    # rather than once per count.
    totals = ldf.select(
        [
            pl.col("students").filter(pl.col("state").is_in(states)).sum().alias(region)
            for region, states in region_dict.items()
        ]
        + [pl.col("students").sum().alias("__total__")]
    ).collect().to_dicts()[0]

    count_total = totals["__total__"]
    if not count_total:
        raise ValueError("total student count is zero; cannot compute census region shares")
    return {region: totals[region] / count_total for region in region_dict}


def build_table(dir: str):
    """Compute Zearn and MDR student shares per census region and save them to Excel.

    Args:
        dir: base directory; inputs are read from data/... and the workbook is
            written to results/new_app_table_8_c.xlsx under it.

    The sheet "new_app_table_8_c" has no header row and no index column; each
    row is (region name, Zearn share, MDR share).
    """
    zearn_path = os.path.join(dir, r"data/derived/Zearn/intermediate/zearn_table_data.csv")
    mdr_path = os.path.join(dir, r"data/dvc/Zearn/other/School demos_6.1.20_Added detail.csv")

    zearn_shares = get_census_region_shares(crosswalk_states(get_zearn_counts(zearn_path), dir))
    mdr_shares = get_census_region_shares(crosswalk_states(get_mdr_counts(mdr_path), dir))

    # The downstream table formatting expects exactly this row order.
    row_order = [("Midwest", "midwest"), ("Northeast", "northeast"), ("South", "south"), ("West", "west")]

    # Built with pandas (not polars) for the Excel export.
    table = pd.DataFrame({
        "region": [label for label, _ in row_order],
        "zearn_shares": [zearn_shares[key] for _, key in row_order],
        "mdr_shares": [mdr_shares[key] for _, key in row_order],
    })

    table.to_excel(
        os.path.join(dir, "results/new_app_table_8_c.xlsx"),
        sheet_name = "new_app_table_8_c",
        header = False,
        index = False,
    )


if __name__ == '__main__':
    # Expects one argument: the base directory passed to build_table, under
    # which the data/ inputs and the results/ output live.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <base_dir>")
    build_table(sys.argv[1])
